{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/experimental/bert-search-test/bertcomparison.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/experimental/bert-search-test/bertcomparison.ipynb)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-s-3dmlx769J"
      },
      "source": [
        "# BERT search in Pinecone"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s8z_fNc18bgT"
      },
      "source": [
        "## **Dependencies**\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W2gLHDxp8LPg"
      },
      "outputs": [],
      "source": [
        "!pip install --quiet pandas\n",
        "!pip install --quiet progressbar2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_3IkNj1-Vf5t"
      },
      "outputs": [],
      "source": [
        "import re\n",
        "import bz2\n",
        "import time\n",
        "import numpy\n",
        "import pandas as pd\n",
        "from typing import List\n",
        "from statistics import mean\n",
        "from progressbar import progressbar"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E3dYi6DH8eRj"
      },
      "source": [
        "## **Dataset**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uA8F_1O_FZW7"
      },
      "source": [
        "The dataset used in this notebook is the dbpedia dataset that contains full abstracts of Wikipedia articles, usually the first section.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gAKe4Xaq8U4D"
      },
      "source": [
        "Downloading the dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7oDRUMS7-z-B"
      },
      "outputs": [],
      "source": [
        "!rm long_abstracts_en.ttl.bz2\n",
        "!wget http://downloads.dbpedia.org/2016-10/core-i18n/en/long_abstracts_en.ttl.bz2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5BLVAYWSGX3M"
      },
      "source": [
        "We will be conducting a similar test as described in this blog post: [Speeding up BERT Search in Elasticsearch](https://towardsdatascience.com/speeding-up-bert-search-in-elasticsearch-750f1f34f455#e7c4-d62eca28921b). The code is avaliable on this link: https://github.com/DmitryKey/bert-solr-search.git."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-QbhcI2j8XhJ"
      },
      "source": [
        "**Parsing the dataset**\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4edpjS7dG6_I"
      },
      "source": [
        "We will be using the same method that was used for parsing the dataset in the blogpost. Original source of this method can be found on this link: https://github.com/DmitryKey/bert-solr-search/blob/master/src/data_utils.py"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jdFGeufDTmuc"
      },
      "outputs": [],
      "source": [
        "def parse_dbpedia_data(source_file, max_docs: int):\n",
        "    \"\"\"\n",
        "    Parses the input file of abstracts and returns an iterable\n",
        "    :param max_docs: maximum number of input documents to process; -1 for no limit\n",
        "    :param source_file: input file\n",
        "    :return: yields document by document to the consumer\n",
        "    \"\"\"\n",
        "    global VERBOSE\n",
        "    count = 0\n",
        "    max_tokens = 0\n",
        "\n",
        "    if -1 < max_docs < 50:\n",
        "        VERBOSE = True\n",
        "\n",
        "    percent = 0.1\n",
        "    bulk_size = (percent / 100) * max_docs\n",
        "\n",
        "    print(f\"bulk_size={bulk_size}\")\n",
        "\n",
        "    if bulk_size <= 0:\n",
        "        bulk_size = 1000\n",
        "\n",
        "    for line in source_file:\n",
        "        line = line.decode(\"utf-8\")\n",
        "\n",
        "        # skip commented out lines\n",
        "        comment_regex = '^#'\n",
        "        if re.search(comment_regex, line):\n",
        "            continue\n",
        "\n",
        "        token_size = len(line.split())\n",
        "        if token_size > max_tokens:\n",
        "            max_tokens = token_size\n",
        "\n",
        "        # skip lines with 20 tokens or less, because they tend to contain noise\n",
        "        # (this may vary in your dataset)\n",
        "        if token_size <= 20:\n",
        "            continue\n",
        "\n",
        "        first_url_regex = '^<([^\\>]+)>\\s*'\n",
        "\n",
        "        x = re.search(first_url_regex, line)\n",
        "        if x:\n",
        "            url = x.group(1)\n",
        "            # also remove the url from the string\n",
        "            line = re.sub(first_url_regex, '', line)\n",
        "        else:\n",
        "            url = ''\n",
        "\n",
        "        # remove the second url from the string: we don't need to capture it, because it is repetitive across\n",
        "        # all abstracts\n",
        "        second_url_regex = '^<[^\\>]+>\\s*'\n",
        "        line = re.sub(second_url_regex, '', line)\n",
        "\n",
        "        # remove some strange line ending, that occurs in many abstracts\n",
        "        language_at_ending_regex = '@en \\.\\n$'\n",
        "        line = re.sub(language_at_ending_regex, '', line)\n",
        "\n",
        "        # form the input object for this abstract\n",
        "        doc = {\n",
        "            \"_text_\": line,\n",
        "            \"url\": url,\n",
        "            \"id\": count+1\n",
        "        }\n",
        "\n",
        "        yield doc\n",
        "        count += 1\n",
        "\n",
        "        if count % bulk_size == 0:\n",
        "            print(f\"Processed {count} documents\", end=\"\\r\")\n",
        "\n",
        "        if count == max_docs:\n",
        "            break\n",
        "\n",
        "    source_file.close()\n",
        "    print(\"Maximum tokens observed per abstract: {}\".format(max_tokens))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Kehp5R4sewc"
      },
      "source": [
        "If you are experiencing an issue with RAM, lower the number of MAX_DOCS."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vpm3LXNqMLzK"
      },
      "outputs": [],
      "source": [
        "MAX_DOCS = 1000000\n",
        "\n",
        "source_file = bz2.BZ2File(\"long_abstracts_en.ttl.bz2\", \"r\")\n",
        "docs_iter = parse_dbpedia_data(source_file, MAX_DOCS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IrFGTLsJ83-U"
      },
      "source": [
        "**Creating a pandas dataframe**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f9hgBuz2PhXt"
      },
      "outputs": [],
      "source": [
        "id = []\n",
        "text = []\n",
        "\n",
        "for doc in docs_iter:\n",
        "    id.append(doc['id'])\n",
        "    text.append(doc['_text_'])\n",
        "\n",
        "data = pd.DataFrame({'id': id, 'text': text})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dmYfZ3pSmfPl"
      },
      "outputs": [],
      "source": [
        "data.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F5R4pW2l9BQs"
      },
      "source": [
        "**Generating embeddings using BERT**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uVgsz8k_i7pw"
      },
      "source": [
        "Generating embeddings is a time consuming process. Please use GPU or lower the number of MAX_DOCS. On Google Colab you should be expecting around 1.5 hours for 1M documents with GPU."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZFxtg2PSqUW8"
      },
      "outputs": [],
      "source": [
        "!pip install --quiet sentence_transformers==1.0.4\n",
        "!pip install --quiet tqdm==4.41.1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dExZaswNqUDJ"
      },
      "outputs": [],
      "source": [
        "from sentence_transformers import SentenceTransformer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1CvYEqNxZG0j"
      },
      "outputs": [],
      "source": [
        "# # expensive: downloads the model, creates embeddings\n",
        "import os\n",
        "import h5py\n",
        "\n",
        "#If embeddings not present, run inferencing and store vectors as h5 file\n",
        "if not os.path.exists('embeddings.h5'):\n",
        "    model = SentenceTransformer('bert-base-nli-mean-tokens')\n",
        "    sentence_embeddings = model.encode(text, show_progress_bar=True)\n",
        "    file = h5py.File(\"embeddings.h5\", \"w\")\n",
        "    file.create_dataset('bert', data=sentence_embeddings)\n",
        "    file.close()\n",
        "    \n",
        "#If embedding file present, load the vectors directly\n",
        "hf = h5py.File('embeddings.h5', 'r')\n",
        "vec_embeds = list(hf.get('bert'))\n",
        "\n",
        "#Add embeddings to DataFrame\n",
        "data['embeddings'] = pd.Series(vec_embeds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eZfJgh_U9NRs"
      },
      "source": [
        "## **Pinecone**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8QBR_GvF8UM2"
      },
      "outputs": [],
      "source": [
        "!pip install --quiet -U pinecone-client"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k8dxpp3tXwTN"
      },
      "outputs": [],
      "source": [
        "from pinecone import Pinecone"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TzW6Jg9v-mX4"
      },
      "outputs": [],
      "source": [
        "# load Pinecone API key\n",
        "api_key = 'YOUR_API_KEY'\n",
        "\n",
        "pinecone.init(api_key=api_key)\n",
        "\n",
        "index_name = 'bert-stats-test'\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FEGEid3bFIMl"
      },
      "source": [
        "[Get the Pinecone API key](https://www.pinecone.io/start/) if you don’t have one already."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wWhDohHr_c4r"
      },
      "outputs": [],
      "source": [
        "items_to_upload = data[['id', 'embeddings']]\n",
        "items_to_upload = [tuple(x) for x in items_to_upload.to_numpy()]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NqrKHscA7e_i"
      },
      "source": [
        "We are defining a variable which we will be using to query vectors in batches. The reason for this is to make our results comparable to the ones published in the blog. By querying in batches and then dividing the elapsed time with the same number in the end, we minimize the influence of the networking time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ay6F_PfuJ6yl"
      },
      "outputs": [],
      "source": [
        "def upload_items(items_to_upload: List, batch_size: int) -> float:\n",
        "    print(f\"\\nUpserting {len(items_to_upload)} vectors...\")\n",
        "    start = time.perf_counter()\n",
        "    upsert_cursor = index.upsert(items=items_to_upload,batch_size=batch_size)\n",
        "    end = time.perf_counter()\n",
        "    return (end - start) / 60.0 # minutes\n",
        "\n",
        "def restart_service(index_name: str, shards: int, timeout: int = 300):\n",
        "    if index_name in pinecone.list_indexes().names():\n",
        "        pinecone.delete_index(index_name)\n",
        "    pinecone.create_index(index_name,metric='cosine', shards=shards)\n",
        "    index = pinecone.Index(index_name)\n",
        "    return index\n",
        "\n",
        "def query(test_items: List, index):\n",
        "    print(f\"\\nQuerying...\")\n",
        "    times = []\n",
        "    #For single queries, we pick 10 queries \n",
        "    for test_item in test_items[:10]:\n",
        "        start = time.perf_counter()\n",
        "        #test_item is an array of [id,vector]\n",
        "        query_results = index.query(queries=[test_item[1]],disable_progress_bar=True)              # querying vectors with top_k=10\n",
        "        end = time.perf_counter()\n",
        "        times.append((end-start))           \n",
        "    #For batch queries, we pick 100 vectors at perform 10 queries\n",
        "    print(f\"\\n Batch Querying...\")\n",
        "    batch_times = []\n",
        "    for i in range(0,10000, 1000):\n",
        "        start = time.perf_counter() \n",
        "        batch_items = test_items[i:i+1000]\n",
        "        vecs = [item[1] for item in batch_items]\n",
        "        query_results = index.query(queries=vecs,disable_progress_bar=True)\n",
        "        end = time.perf_counter()\n",
        "        batch_times.append((end-start))\n",
        "    return mean(times)*1000,mean(batch_times)*1000\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e7f9HYTh9seN"
      },
      "source": [
        "Testing uploading and querying"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ry4gFTSS-mr0"
      },
      "outputs": [],
      "source": [
        "BATCH_SIZE = 1000\n",
        "NUMBER_OF_DOCS = [10000, 100000, 200000, 400000, 600000, 800000, 1000000]\n",
        "\n",
        "\n",
        "upsert_times = {}                  \n",
        "query_times = {}\n",
        "batch_query_times = {}\n",
        "for doc_size in progressbar(NUMBER_OF_DOCS):\n",
        "    if doc_size > len(items_to_upload):\n",
        "        print(f\"There are no {doc_size} vectors to be uploaded.\")\n",
        "        continue\n",
        "    test_vectors = items_to_upload[:10000]\n",
        "    index = restart_service(index_name, shards=3)\n",
        "    time_for_upsert = upload_items(items_to_upload[:doc_size], BATCH_SIZE)\n",
        "    time_for_query,time_for_batch_query = query(test_vectors, index)\n",
        "    upsert_times[doc_size] = time_for_upsert\n",
        "    query_times[doc_size] = time_for_query\n",
        "    batch_query_times[doc_size] = time_for_batch_query"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SPZggBT39yUL"
      },
      "source": [
        "## **Displaying results**\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yk1KmShtfszG"
      },
      "outputs": [],
      "source": [
        "time_results = pd.DataFrame({\n",
        "    'number_of_docs': upsert_times.keys(),\n",
        "    'indexing_time(min)': upsert_times.values(),\n",
        "    'avg_search_speed(ms)': query_times.values(),\n",
        "    'avg_batch_search_speed(ms), batch_size=1000':batch_query_times.values()\n",
        "})\n",
        "time_results['index_size(mb)'] = (time_results['number_of_docs'] * len(items_to_upload[0][1]) * 32) / 8000000 # megabytes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qs1esevPftNY"
      },
      "outputs": [],
      "source": [
        "time_results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KRY_Ur6OxRom"
      },
      "outputs": [],
      "source": [
        "#For batch querying, sending 100 queries at once\n",
        "print(query_times)\n",
        "print(batch_query_times)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RaaXX2iTRaUg"
      },
      "outputs": [],
      "source": [
        "time_results.plot(x=\"number_of_docs\", y=[\"indexing_time(min)\"], kind=\"bar\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8CSjoflnRvqS"
      },
      "outputs": [],
      "source": [
        "time_results.plot(x=\"number_of_docs\", y=[\"avg_search_speed(ms)\"], kind=\"bar\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2VpB2RslxRon"
      },
      "outputs": [],
      "source": [
        "time_results.plot(x=\"number_of_docs\", y=[\"avg_batch_search_speed(ms), batch_size=1000\"], kind=\"bar\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JeRiSIbIxRon"
      },
      "source": [
        "## Estimating Search Speed Without Network Overhead \n",
        "\n",
        "It is difficult to look at the real speed of Pinecone's engine when we are querying a deployment across the cloud which included overheads like network, parsing, authentication etc. These overheads contribute a lot to the observed speeds rather than search speed itself.\n",
        "\n",
        "One way to estimate the speed of the engine is to look at the returned search speeds as :\n",
        "\n",
        "overhead + num_queries\\*search_speed_per_query \n",
        "\n",
        "The overhead will change with every call but it generally does not fluctuate too much,so for rough estimates we can assume this as a constant in our equation.\n",
        "\n",
        "For 1 million documents and assuming the overhead is constant: \n",
        "\n",
        "single query\n",
        "overhead + 1\\*search_speed_per_query = 35.68ms\n",
        "batched query\n",
        "overhead + 1000\\*search_speed_per_query = 7020.4ms\n",
        "\n",
        "Solving for search_speed_per_query\n",
        "\n",
        "999\\*search_speed_per_query = 7020.4-35.68\n",
        "\n",
        "search_speed_per_query = 6.98 ms\n",
        "\n",
        "Using this logic let's look at estimated search speed per query on Pinecone's engine"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C51QGBU9xRoo"
      },
      "outputs": [],
      "source": [
        "speed_per_query = {}\n",
        "for num_docs in query_times.keys():\n",
        "    batch_speed = batch_query_times[num_docs]\n",
        "    speed = query_times[num_docs]\n",
        "    speed_per_query[num_docs] = (batch_speed-speed)/999"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eQE8GbiDxRoo"
      },
      "outputs": [],
      "source": [
        "time_results['speed_per_query'] = speed_per_query.values()\n",
        "time_results.plot(x=\"number_of_docs\", y=[\"speed_per_query\"], kind=\"bar\")\n",
        "speed_per_query"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y0QF8-YJxRoo"
      },
      "source": [
        "## Calculating Recall\n",
        "\n",
        "It's important that while performing ANN search we maintain a good recall while speeding up search.\n",
        "We will calculate the rank-k@k recall by taking results from KNN (exact search) to be the ground truth.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SjGatACvxRoo"
      },
      "outputs": [],
      "source": [
        "#Create an exact index and upload items\n",
        "exact_index_name = 'exactsearch'\n",
        "# exact_index = pinecone.create_index(name=exact_index_name,metric='cosine',shards=3,engine_type='exact')\n",
        "exact_index = pinecone.Index(exact_index_name)\n",
        "upsert_cursor = exact_index.upsert(items=items_to_upload[:doc_size],batch_size=1000)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ocbgPPnLxRop"
      },
      "source": [
        "Once we have uploaded the same items on the exact index as well, we will queries both exact and approximate indexes to compare the results.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qjt3W6ygxRop"
      },
      "outputs": [],
      "source": [
        "import concurrent.futures\n",
        "index = pinecone.Index(name=index_name)\n",
        "NUM_TEST_QUERIES = 10\n",
        "with concurrent.futures.ThreadPoolExecutor() as executor: \n",
        "    approx_res = executor.map(lambda i: index.unary_query( test_vectors[i][1], top_k=100), range(NUM_TEST_QUERIES))  \n",
        "    \n",
        "exact_index = pinecone.Index(name=exact_index_name)\n",
        "with concurrent.futures.ThreadPoolExecutor() as executor:     \n",
        "    exact_res = executor.map(lambda i: exact_index.unary_query( test_vectors[i][1], top_k=100), range(NUM_TEST_QUERIES))  \n",
        "    "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XrsZSrUQxRop"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "def anns_recall(r_exact, r):\n",
        "    assert(len(r_exact.scores) == len(r.scores))\n",
        "    exact_rank_k_score = r_exact.scores[-1]\n",
        "    indicator = [s >= exact_rank_k_score for s in r.scores]\n",
        "    return sum(indicator) / len(indicator)\n",
        "\n",
        "\n",
        "def approx_loss(r_exact, r):\n",
        "    return np.quantile([ abs(ext_s - apprx_s) for ext_s, apprx_s in zip(r_exact.scores, r.scores)], 0.5)\n",
        "\n",
        "\n",
        "recalls = []\n",
        "a_loss = []\n",
        "for exact_r, r in zip(exact_res, approx_res):\n",
        "    recalls.append( anns_recall(exact_r, r) )\n",
        "    a_loss.append(approx_loss(exact_r, r))\n",
        "\n",
        "print(\"Accuracy results over 10 test queries:\")\n",
        "print(f\"The average recall @rank-k is {sum(recalls)/len(recalls)}\")\n",
        "print(f\"The median approximation loss is {np.quantile(a_loss, 0.5)}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qgDaQOqdjkMx"
      },
      "source": [
        "Delete the indexes, once done."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lKLueGcHjjWs"
      },
      "outputs": [],
      "source": [
        "pinecone.delete_index(index_name)\n",
        "pinecone.delete_index(exact_index_name)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "machine_shape": "hm",
      "name": "bertcomparison.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "environment": {
      "name": "common-cu110.m71",
      "type": "gcloud",
      "uri": "gcr.io/deeplearning-platform-release/base-cu110:m71"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
